Add examples for MoE models - Mixtral in TE #2642
Conversation
Greptile Summary: This PR adds a comprehensive tutorial for integrating HuggingFace Mixtral (MoE) with Transformer Engine, covering TE GroupedLinear, expert parallelism via NCCL all-to-all, quantized recipes, and a fused tier-9 path (
Confidence Score: 2/5
Several forward-pass crashes remain unresolved across both model files; the PR is not ready to merge in its current state. The loop-mode EP path in te_mixtral.py crashes on every forward call when ep_size > 1 due to calling a nonexistent method on nn.Parameter. This compounds previously flagged but still-open bugs: the should_pack_inputs assertion fires on every decode step, input_ids.shape crashes when callers pass inputs_embeds inside an InferenceParams branch, and flash-attn is required but absent from the dependency file. te_mixtral.py (loop-mode EP forward crash + decode assertion + inputs_embeds guard) and te_mixtral_mxfp8.py (decode assertion + inputs_embeds guard) need the most attention before merge.
Important Files Changed
Reviews (29): Last reviewed commit: "Merge branch 'add-moe-example' of github..."
Hi @sudhakarsingh27, can you check if this addresses your comments? Tested on 2x H100.
Agreed! I think any example should show some perf gain and include the whole weight mapping so the user can run the example.
Documentation build is not working; if you fix it, please ping me and I'll review.
d7301a6 to fcc7e35 (force-pushed)
72e0e45 to 226173a (force-pushed)
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
- Fix m_splits bug: use (selected_experts == i).sum() instead of .any(dim=-1).sum() * top_k, which caused dimension mismatches in GroupedLinear
- Add te_mixtral.py: TEMixtralSparseMoeBlock (GroupedLinear + moe_permute/unpermute), replace_moe_block context manager, TEMixtralForCausalLM with HF weight loading, and replace_params for expert weight packing
- Add utils.py: HyperParameters, data loading, BF16/FP8 model init, Accelerate wrapping, fine-tuning loop — mirrors te_llama/utils.py style
- Add requirements.txt matching te_llama versions
- Expand notebook from bare code snippet to full tutorial covering: architecture overview, HF vs TE comparison table, unit-test cell, baseline BF16 run, BF16 TE improvement, FP8 TE improvement, expert routing/scaling discussion, generalisation guide for other MoE models

Addresses reviewer feedback: fixes the critical runtime bug (greptile) and expands to tutorial quality comparable to the Llama/Gemma examples (sudhakarsingh27), covering the scope requested in issue NVIDIA#2573.

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
A few bugs found during code review:
- moe_permute and moe_unpermute were missing map_type='index'. The
selected_experts tensor is [num_tokens, top_k] indices, not a mask,
so without this the routing is completely wrong at runtime.
- num_out_tokens=None should be -1 (the API expects an int).
- moe_unpermute: replaced deprecated probs= with merging_probs= and
kept routing_weights in float32 as TE recommends.
- utils.py: num_warmup_steps was hardcoded to 100 instead of using
hyperparams.num_warmup_steps, which made the benchmark meaningless.
- requirements.txt: transformers==4.57.0 doesn't exist, fixed to 4.47.1.
- Notebook generalisation guide: updated code template with the same fixes.
Tested on 2xH100 in nvcr.io/nvidia/pytorch:26.01-py3 (PyTorch 2.10.0, CUDA 13.1):
$ python3 -c "
from te_mixtral import TEMixtralSparseMoeBlock
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import torch
cfg = MixtralConfig(hidden_size=256, intermediate_size=512,
num_local_experts=4, num_experts_per_tok=2)
x = torch.randn(2, 8, cfg.hidden_size, device='cuda', dtype=torch.bfloat16)
hf_out, hf_logits = MixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
te_out, te_logits = TEMixtralSparseMoeBlock(cfg).cuda().bfloat16()(x)
assert hf_out.shape == te_out.shape
assert hf_logits.shape == te_logits.shape
print('PASS')"
Input shape : torch.Size([2, 8, 256])
Output shape : torch.Size([2, 8, 256]) (matches HF: True)
Logits shape : torch.Size([16, 4]) (matches HF: True)
Output dtype : torch.bfloat16
PASS
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
The docs CI was failing because nbsphinx auto mode tries to execute notebooks with no stored outputs. Add explicit execute:never metadata so the docs builder renders the notebook as-is without running cells. Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…otebook

te_mixtral.py:
- Replace per-expert .item() loop with torch.bincount + .tolist() (8 GPU syncs → 1), eliminating the main cause of GroupedLinear speedup regression
- Remove unnecessary num_out_tokens/max_token_num args from moe_permute
- Fix replace_params to use .copy_() instead of fragile .data[] assignment, and load_state_dict from fully-populated te_state in one shot
- Add device_map="auto" support in from_hf_model via accelerate.dispatch_model
- Rename replace_params -> _pack_expert_weights for clarity

tutorial_accelerate_hf_mixtral_with_te.ipynb:
- Remove shape-check section (redundant)
- Add FP8 prefill and decode-regime benchmark cells with summary table
- Restructure training sections to match te_llama pattern: restart at top of each cell, combined hyperparams + init + wrap + finetune, concise markdown
- Bump benchmark SEQ 512 -> 2048 for realistic H100 workload

utils.py:
- Apply user batch size and sequence length adjustments

Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
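For reference, a minimal sketch of the bincount-based m_splits computation described in that commit message; the function name and exact tensor shapes are assumptions that follow the surrounding discussion, not the code in te_mixtral.py:

```python
import torch

def compute_m_splits(selected_experts: torch.Tensor, num_experts: int) -> list:
    # selected_experts: [num_tokens, top_k] expert indices from the router.
    # A single bincount + .tolist() yields the per-expert token counts that
    # GroupedLinear expects as m_splits, with one GPU->CPU sync instead of
    # one .item() call per expert.
    counts = torch.bincount(selected_experts.flatten(), minlength=num_experts)
    return counts.tolist()

# Example: 8 tokens routed to top-2 of 4 experts; the splits sum to 16.
selected = torch.randint(0, 4, (8, 2))
print(compute_m_splits(selected, num_experts=4))
```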
for more information, see https://pre-commit.ci Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci
Notebook Tier bash commands misaligned with run_finetune_ep.py improvement indices
The bash commands shown for Tiers 2–4 reference --improvement values that map to different configurations in run_finetune_ep.py than the notebook describes. Concretely:
| Notebook Tier | Described as | Bash uses | run_finetune_ep.py actually does |
|---|---|---|---|
| 2 | TE GroupedLinear EP | --improvement 2 | expert_ffn_mode="loop" (Python F.linear, not GroupedLinear) |
| 3 | Fused DeepEP dispatcher (BF16) | --improvement 3 | expert_ffn_mode="grouped_op", no DeepEP |
| 4 | MXFP8 precision | --improvement 4 | DeepEP BF16 GroupedLinear, not MXFP8 |
Additionally, every multi-GPU bash command omits --ep-size 8, so the default --ep-size 2 (4 experts/rank) is used, whereas the Python cells in each tier all set hp.expert_parallel_size = 8 (1 expert/rank). The benchmark step-time numbers in the results tables were measured at EP=8, so users following the bash commands will reproduce different performance characteristics.
The correct mapping from notebook tier to --improvement flag appears to be: Tier 2 → --improvement 6, Tier 3 → --improvement 4, Tier 4 → --improvement 8 (all with --ep-size 8 added).
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…ne into add-moe-example
for more information, see https://pre-commit.ci
torch
torchao!=0.14.0
transformer_engine[pytorch]
transformers
accelerate
datasets
safetensors
huggingface_hub
tokenizers
Both init_baseline_model and init_te_mixtral_model in utils.py explicitly set config._attn_implementation = "flash_attention_2". HuggingFace transformers validates this at from_pretrained time and raises ImportError if the flash-attn package is absent. Every tutorial user following the baseline or TE fine-tuning paths hits this error before the first training step. Add flash-attn to requirements.txt (the flash-attn PyPI package).
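A minimal sketch of a defensive alternative, in case pinning flash-attn is not desired; the helper name and the "sdpa" fallback are assumptions, not part of this PR (the review's actual fix is to add flash-attn to requirements.txt):

```python
import importlib.util

def pick_attn_implementation() -> str:
    # Use FlashAttention-2 only when the flash-attn package is importable,
    # otherwise fall back to PyTorch SDPA so from_pretrained() does not
    # raise ImportError for users without flash-attn installed.
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attention_2"
    return "sdpa"

# config._attn_implementation = pick_attn_implementation()
```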
…to add-moe-example
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
model.load_state_dict(te_state_dict, strict=False)
del hf_model
Silent weight-mapping failures in init_te_mixtral_model
load_state_dict(strict=False) swallows all missing and unexpected keys, so if replace_params fails to map any parameter (e.g., due to a checkpoint key-name mismatch), the model silently trains with random initialization. test_accuracy.py shows the correct pattern: capture the return value and assert that the only missing keys are TE-internal _extra_state entries.
Suggested change: replace

model.load_state_dict(te_state_dict, strict=False)
del hf_model

with

missing, unexpected = model.load_state_dict(te_state_dict, strict=False)
if unexpected:
    raise RuntimeError(f"Unexpected keys when loading TE state dict: {unexpected}")
non_extra_missing = [k for k in missing if not k.endswith("_extra_state")]
if non_extra_missing:
    raise RuntimeError(f"Missing non-extra-state keys in TE model: {non_extra_missing}")
del hf_model
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
should_pack_inputs = not any(has_thd_input) and self.config.attn_input_format == "thd"

if should_pack_inputs:
    assert (
        attention_mask is not None
    ), "Attention mask is required when packing BSHD inputs."
Decode-step crash with default attn_input_format="thd" when attention_mask is omitted
should_pack_inputs is True whenever attn_input_format == "thd" and no explicit THD kwargs are supplied, which is the normal case for every decode step during generation. In that path the assertion fires immediately because HF generate() typically does not propagate attention_mask on decode steps (sequence length 1). Adding a guard like past_key_values is None (or hidden_states.size(1) > 1) to should_pack_inputs would skip packing for single-token decode steps where the sequence is already effectively packed.
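A minimal sketch of the suggested guard, written as a standalone helper; the helper name is hypothetical and the argument names mirror the snippet above, while the exact integration point in te_mixtral.py is assumed:

```python
def compute_should_pack_inputs(has_thd_input, attn_input_format, past_key_values) -> bool:
    # Pack BSHD inputs into THD only on prefill (no KV cache yet). Single-token
    # decode steps from generate(), which often arrive without attention_mask,
    # skip packing and therefore never hit the assertion.
    return (
        not any(has_thd_input)
        and attn_input_format == "thd"
        and past_key_values is None
    )
```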
routing_weights_for_unpermute = routing_weights
map_type = "index"

if self._ep_group is not None:
Non-standard sentinel for no-token-dropping in moe_permute
The TE API documents -1 as the sentinel meaning "no token dropping"; the value used here (0) only happens to work because the fake/shape-inference path guards with num_out_tokens > 0, and 0 fails that check. If a future TE CUDA kernel instead treats 0 as a literal output-token count, permuted_hidden becomes empty and every all-to-all dispatch silently sends no tokens to any expert. Use -1 to match the documented contract.
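A sketch of the corrected call, using the kwargs already discussed in this thread; the exact moe_permute signature depends on the TE version and the tensor shapes are illustrative assumptions:

```python
import torch
from transformer_engine.pytorch import moe_permute

hidden_states = torch.randn(16, 256, device="cuda", dtype=torch.bfloat16)
selected_experts = torch.randint(0, 8, (16, 2), device="cuda", dtype=torch.int32)

permuted_hidden, row_id_map = moe_permute(
    hidden_states,
    selected_experts,
    map_type="index",   # selected_experts holds [num_tokens, top_k] indices, not a mask
    num_out_tokens=-1,  # documented "no token dropping" sentinel, not 0
)
```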
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci
if isinstance(past_key_values, InferenceParams):
    lengths = (
        attention_mask.sum(dim=1).tolist()
        if attention_mask.shape == input_ids.shape
        else [1] * input_ids.shape[0]
    )
    past_key_values.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths)))
input_ids crash when inputs_embeds is supplied
When a caller passes inputs_embeds instead of input_ids, input_ids is None. Both branches of the ternary — attention_mask.shape == input_ids.shape and [1] * input_ids.shape[0] — then raise AttributeError: 'NoneType' object has no attribute 'shape'. The fix is to derive the reference tensor from whichever of input_ids/inputs_embeds is non-None, mirroring the guard already needed in the parent te_mixtral.py.
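A minimal sketch of that fix as a standalone helper; the helper name is hypothetical, and the shape[:2] comparison is an assumption that generalizes the original check so it also works for [batch, seq, hidden]-shaped inputs_embeds:

```python
def sequence_lengths(input_ids, inputs_embeds, attention_mask):
    # Derive per-sequence lengths from whichever of input_ids / inputs_embeds
    # was actually supplied, so callers passing only inputs_embeds no longer crash.
    ref = input_ids if input_ids is not None else inputs_embeds
    if attention_mask is not None and attention_mask.shape == ref.shape[:2]:
        return attention_mask.sum(dim=1).tolist()
    return [1] * ref.shape[0]
```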
if should_pack_inputs:
    assert attention_mask is not None, "attention_mask required when packing BSHD."
Assertion crash during auto-regressive decode
should_pack_inputs is True whenever attn_input_format == "thd" and no THD kwargs are already present, which is the default for every decode step in model.generate(). On decode steps HuggingFace typically does not forward attention_mask, so this assertion fires immediately: AssertionError: attention_mask required when packing BSHD. A guard like past_key_values is None (or hidden_states.size(1) > 1) would skip packing for single-token steps where packing is unnecessary.
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…ne into add-moe-example
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
for more information, see https://pre-commit.ci
Added accuracy test between TE and HF - passed
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
Signed-off-by: Faradawn Yang <73060648+faradawn@users.noreply.github.com>
…ne into add-moe-example
if isinstance(gate_up_w, DTensor) or isinstance(gate_up_w.data, DTensor):
    gate_up_w = gate_up_w.to_local()
down_w = self.experts_down_weight
if isinstance(down_w, DTensor) or isinstance(down_w.data, DTensor):
    down_w = down_w.to_local()
to_local() called on nn.Parameter, not on DTensor — crashes under EP
After set_ep_group(), self.experts_gate_up_weight is nn.Parameter(DTensor.from_local(...)). The check isinstance(gate_up_w.data, DTensor) correctly identifies this case, but then calls gate_up_w.to_local() on the nn.Parameter wrapper. nn.Parameter is a separate torch.Tensor subclass and does not define to_local(), so every forward pass in loop mode with ep_size > 1 raises AttributeError: 'Parameter' object has no attribute 'to_local'.
The in-code comment correctly explains that .data must be avoided to preserve the autograd graph, but the fix cannot be simply calling to_local() on the Parameter either. One approach that preserves the graph is to re-wrap the local shard: gate_up_w = type(gate_up_w)(gate_up_w.data.to_local()). The same issue applies to down_w on lines 710-711.
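A sketch of one way to apply that suggestion, wrapped in a small helper so both weights go through the same path; the helper name and the DTensor import path (recent PyTorch) are assumptions, and the re-wrapping follows the review comment rather than a verified fix:

```python
import torch
from torch.distributed.tensor import DTensor  # torch.distributed._tensor on older PyTorch

def local_shard(w: torch.Tensor) -> torch.Tensor:
    # Unwrap the local shard whether `w` is a DTensor itself or an nn.Parameter
    # whose underlying data is a DTensor; re-wrap in the original class, as the
    # review comment suggests, so callers keep the same object type.
    if isinstance(w, DTensor):
        return w.to_local()
    if isinstance(w.data, DTensor):
        return type(w)(w.data.to_local())
    return w

# gate_up_w = local_shard(self.experts_gate_up_weight)
# down_w = local_shard(self.experts_down_weight)
```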
Summary
This PR adds a complete tutorial for integrating HuggingFace Mixtral (MoE) with Transformer Engine, addressing the gap identified in #2573.
What's included
- te_mixtral.py — Drop-in TEMixtralSparseMoeBlock that replaces HF's loop-over-experts with TE's GroupedLinear (batched GEMM) + moe_permute/moe_unpermute. Includes the replace_moe_block context manager, TEMixtralForCausalLM with HF weight loading, and replace_params for expert weight packing.
- utils.py — Data loading, BF16/FP8 model init, Accelerate wrapping, fine-tuning loop; mirrors the te_llama/utils.py style.
- requirements.txt — Pinned dependencies matching the Llama/Gemma tutorials.
- Tutorial notebook (tutorial_accelerate_hf_mixtral_with_te.ipynb) in the style of te_llama and te_gemma, covering:

Bug fix

Corrected the m_splits calculation flagged by the automated review:

Scope
Covers all topics requested in #2573: